{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 导入包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.nn import GCNConv\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyG构建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2708/2708 [00:00<00:00, 3162.58it/s]\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "Build the graph node dataset from the Cora files using torch_geometric.data.Data.\n",
    "Relevant Data arguments:\n",
    "    x : torch.Tensor, node feature matrix, shape [num_nodes, num_node_features]\n",
    "    edge_index : LongTensor, graph connectivity, shape [2, num_edges]\n",
    "    edge_attr : None, not needed here\n",
    "    y : Tensor, graph/node labels, arbitrary shape\n",
    "    pos / norm / face : not needed here\n",
    "'''\n",
    "content_path = \"./cora/cora.content\"\n",
    "cite_path = \"./cora/cora.cites\"\n",
    "\n",
    "# Read the raw text files\n",
    "with open(content_path, \"r\") as fp:\n",
    "    contents = fp.readlines()\n",
    "with open(cite_path, \"r\") as fp:\n",
    "    cites = fp.readlines()\n",
    "\n",
    "# Edge list: each line is two tab-separated paper ids\n",
    "cites = list(map(lambda x: x.strip().split(\"\\t\"), cites))\n",
    "\n",
    "# Build lookup tables from the content file: id, binary features, label per line\n",
    "paper_list, feat_list, label_list = [], [], []\n",
    "for line in tqdm(contents):\n",
    "    tag, *feat, label = line.strip().split(\"\\t\")\n",
    "    paper_list.append(tag)\n",
    "    feat_list.append(np.array(list(map(int, feat))))\n",
    "    label_list.append(label)\n",
    "# Paper id -> node index\n",
    "paper_dict = {key: val for val, key in enumerate(paper_list)}\n",
    "# Label name -> class index (sorted so indices are deterministic across runs;\n",
    "# a bare set() iterates in hash order, which varies with PYTHONHASHSEED)\n",
    "label_dict = {key: val for val, key in enumerate(sorted(set(label_list)))}\n",
    "# Build edge_index, then mirror the edges so the graph is undirected\n",
    "cites = np.array([[paper_dict[i[0]], paper_dict[i[1]]] for i in cites], \n",
    "                 np.int64).T                                 # (2, edge)\n",
    "cites = np.concatenate((cites, cites[::-1, :]), axis=1)      # (2, 2*edge), i.e. (2, E)\n",
    "# Label vector\n",
    "y = np.array([label_dict[i] for i in label_list])\n",
    "# Tensors for the Data object\n",
    "x = torch.from_numpy(np.array(feat_list, dtype=np.float32))  # [N, Feat_Dim]\n",
    "edge_index = torch.from_numpy(cites)                         # [2, E]\n",
    "y = torch.from_numpy(y)                                      # [N, ]\n",
    "\n",
    "# Build the Data object\n",
    "data = Data(x=x,\n",
    "            edge_index=edge_index,\n",
    "            y=y)\n",
    "# Train/test split masks. torch.bool, not torch.uint8: indexing with uint8\n",
    "# masks is deprecated (see the warnings recorded in the outputs below).\n",
    "data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
    "data.train_mask[:data.num_nodes - 1000] = True               # ~1700 training nodes\n",
    "data.val_mask = None                                         # no validation split\n",
    "data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n",
    "data.test_mask[data.num_nodes - 500:] = True                 # 500 test nodes\n",
    "data.num_classes = len(label_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 打印数据信息"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(edge_index=[2, 10858], num_classes=7, test_mask=[2708], train_mask=[2708], x=[2708, 1433], y=[2708])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********************Data Info********************\n",
      "==> Is undirected graph : True\n",
      "==> Number of edges : 10858/2=5429\n",
      "==> Number of nodes : 2708\n",
      "==> Node feature dim : 1433\n",
      "==> Number of training nodes : 1708\n",
      "==> Number of testing nodes : 500\n",
      "==> Number of classes : 7\n"
     ]
    }
   ],
   "source": [
    "# Summarize the dataset; num_edges counts both directions, hence the /2\n",
    "print(\"{}Data Info{}\".format(\"*\"*20, \"*\"*20))\n",
    "print(\"==> Is undirected graph : {}\".format(data.is_undirected()))\n",
    "print(\"==> Number of edges : {}/2={}\".format(data.num_edges, data.num_edges // 2))\n",
    "print(\"==> Number of nodes : {}\".format(data.num_nodes))\n",
    "print(\"==> Node feature dim : {}\".format(data.num_node_features))\n",
    "print(\"==> Number of training nodes : {}\".format(data.train_mask.sum().item()))\n",
    "print(\"==> Number of testing nodes : {}\".format(data.test_mask.sum().item()))\n",
    "print(\"==> Number of classes : {}\".format(data.num_classes))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "********************Train Data anf Test Data Info********************\n",
      "==> Label list : \n",
      "    (0, 'Probabilistic_Methods')\n",
      "    (1, 'Reinforcement_Learning')\n",
      "    (2, 'Genetic_Algorithms')\n",
      "    (3, 'Case_Based')\n",
      "    (4, 'Rule_Learning')\n",
      "    (5, 'Theory')\n",
      "    (6, 'Neural_Networks')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:3: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at  ..\\aten\\src\\ATen/native/IndexingUtils.h:25.)\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n",
      "D:\\Anaconda3\\lib\\site-packages\\ipykernel_launcher.py:12: UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead. (Triggered internally at  ..\\aten\\src\\ATen/native/IndexingUtils.h:25.)\n",
      "  if sys.path[0] == '':\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0.5,1,'Test Data Statics')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xu8VXWd//HXm4vcFFBBQi6iSF4KLwylaWZK05ialpbplJdCyZmyZrQM+5VO2aRjpqnNpIyWUmaQ1U9Ep8ZQM2ZExbt4A1HhBAmIYKaoyGf+WN8Dm+M6hw17r733Ofv9fDz2Y6/13Wt/1+fAZ+/PXrfvUkRgZmbWVrd6B2BmZo3JBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwtEDUjqLukVSSPrHUujkvTfkj5d7zjMqkHS7ZI+Ve84KuUCkSN9mbc+1kl6rWR+s7/EIuKtiNg6IhZtQSy7SoqS9f9Z0s2SJmxGH6dKunNz113yfkn6hqTnUgwtkq4veX22pFM2o7/vSLq2tC0iPhwR17fzFquyaud4Sb9zJH2mg9d3z8nnGZIO2Yx1nC7p9xXEKEnnleTzYklTy/0bcvq7UNLVpW0RcWhETNvSGBuFC0SO9GW+dURsDSwCPlrS9rYvMUk9ahUTsC9wOzBjc5K4Qp8DjgcOTTG8B7izRuu2AmxujlfZW23y+S5gpqTjC15vq0nAscAhKYb9UgzWVkT40cEDeA74UJu27wDTgBuAvwCnAO8D5gCrgKXA5UDPtHwPIIBRaf5n6fX/Su+/G9i5nfXvmv03va19MrAEUJr/BrAw9TcPOCq1jwXWAG8BrwArUvtRwENp+UXANzv4N7gSuLid1/4t9b0m9f+D1P5DoAV4GbgPOCC1Hwm8AbyZlr8/tc8GTinp9/PAkym+x4C9U/vX09/9cnr9g/XOkc7+aCfHuwPfTDm1ArgeGJhe6wf8AliZ8v0eYFvg+21y4fs569odWJvT/g1gccn8ucCzJf//R6T2fVP/a9M6/pzaPw48nPLieeDrHfy9VwMXtvNa7t8A/Kgkn+8F9k/tH2uTz/em9jnAZ0r6/ceSfH4UGJvav0n2ffEy8ARwUL3zYaN/j3oH0OiPdj4830lJ8VGyrbA+ZL+q9yMrBrsATwNfTMvnFYgVwHigJ1mx+Vk762+vQLwz9TkmzR8HDE3x/H1K1iHptVOBO9u8/1Dg3Wn5vVM8R7YTwynAi8BXgL8Burd5faMv99R2IrBd+tu/BvwJ6FXy73dte30AJwCL07qU/tYRwLvSh/8dabmdgV3qnSOd/dFOjk8G/gjsCPQGrgV+kl77MnBjyvseKff7pdc2+mLMWVd7BWLPlM87p/lPleTziWRfrIPSa6cDv2/z/gkpP7oB48iK12HtxHAqsBw4My3bNp/f9jcAJ5EVwZ7A/0v52foD8ELg6vb6SPE/T1bcBOwGDE+fu4XAkNS+C+38UKzXw7uYttzsiLg5ItZFxGsRcV9E3BMRayNiITAFOLiD998YEXMj4k2yX2f7bOb6l6Tn7QAiYnpELE3x/JzsQz++vTdHxO0R8Vha/mGyX4S58UbEtcA/AR8h2xRfJukrHQUXET+NiJURsRa4COhPVuzKcSrZL7z7I/N0RCwm+9XYG3iXpB4R8Wz6t7bq+zwwOSKWRMQa4FvApySJ7NfyYGB0yvf7IuKvFa6vbT5PK8nnn5L9wPib9t4cEbMiYl5a/gFgOu1//q4h+7HzUbIfJi9I+ueOgouIqRHxUvq8fhfYnuwLvRynAt+NiAdTPj8VES1k+dyHrDh2j4iFEfFsmX3WhAvElltcOpMOvt2SDrq9DHwbGNTB+/9cMv0qsPVmrn9Yel6Z1n+KpIclrZK0iuyXWrvrl/Q+SXdKWi5pNVkSt7t8+sKfAAwEvgBc0NGBcklnS3oy9f0S2W6Jjv49So0AnsmJ4SngLLJ/
22WSbpD0jjL7tDKlIjACuLUknx4k+77YnuwL9g/AjemEhe9K6l7hatvm80RJj5Ssf1c6zucDJf2hJJ9PaW/59CV9XUQcQpbPXwIuktTuDzpJ50h6qiSfe3cUTxvt5fM8si21fyXL5+slDSmzz5pwgdhybYfBvYpsX+muEdGfbB+qClz/x8mKzAJJu5DtI/0HYPuIGEi2v7N1/XlD9v4C+BUwIiIGkO2X3WS8EfFmRPyC7DjHu/P6T2eknEl2IHAg2ab5K5uIp9RiYHQ76/9ZRBxItnupO3DBpmK2zRPZfpE/kZ2UMLDk0TsiVkTE6xFxbkTsDnwA+CTZSQyw6f/b9nwcaImIZyW9E7iC7GDydimfF9Bx/kwn21Xbms/XUl4+v5G2uJ+i/Xz+W+CMFONAsq2c1zYRT6mO8vm6iDiAbGukN9nu14bhAlE92wCrgb9K2oNsE73qJA2R9CWyg3pfSx/mrcmSdHm2iE4l24Jo9QIwXFLPNvGujIg1kvZnwwc8b52fk3S4pG0kdZN0BNl+1HtL+i/d3N6GbPN5Bdk+238h24IojWdU+qWa52rgbEn7plMSx0gaIWkPSYdI6kX2AX2N7ICiVd+VwIWSRgBI2kHSR9P0hyTtKakb2cHVtWz4f2ibCx2S9I60e+ccsl/TkOXzOrJ87ibpdDbePfkCMKI1n1MebQ28mPL5ALKi1d46T5V0mKStUz4flfrvKJ/fTPFsRbYF27tNPDtvIp8nS9o75fM7JQ1P/4YHN3I+u0BUz1nAyWQH064i+zVTNel87VeAR4C/A46JiKkAEfEI2VlR95KdEbE72ZklrW4D5pPta23dtfUPZLuJ/kJ2ZtD0Dlb/MuksE7LN6+8CkyLi7vT6D4AT0u6AS4Bbgd+ndT6X3r+0pL9pZB+0lZLupY2IuIHs7Khp6b2/JtsK6UV2PGMF2dbTtikuq76LyP4Pb0858r9kB3Qh2x10ExvOMLqVDflzKXCSpJckXdRO360Xjv6V7MyjCcDRkU6vTccQrgTmkuXNzmm61W/J8mqZpJb0I+l04OIU69nALzv42/4CnEd2VtJLwPnAxIi4r52/4WayY2/PsOGsruUl/f0C6EuWz//bdmXpGMolZAf2X07PA8mOP3w/9beUrMid20HcNdd6iqSZmdlGvAVhZma5Ci0QkgZKujGdzfJEOnNmO0m3SZqfnrdNy0rS5ZIWpLMXxm2qf7N6cW5bMyh6C+Iy4LfpbIe9ya4UnAzMiogxwCw2HJj6CDAmPSaRnZVj1qic29blFXYMQlJ/sgNQu0TJSiQ9RTY8wlJJQ8mu8N1N0lVp+oa2yxUSoNkWcm5bsyhykLldyI70/0TS3sD9ZJfoD2n9YKQP0g5p+WFsfPFZS2rb6EMkaRLZrzD69ev3N7vvXno2p1n13H///SsiYnDOS85t69Q6yO2NFFkgepCdFndGRNwj6TI2bHLnyTuH+G2bNxExhWwYC8aPHx9z585925vMqkHS8+285Ny2Tq2D3N5IkccgWsiujGw9H/9Gsg/VC2nzm/S8rGT5ESXvH86G8VnMGolz25pCYQUiIv4MLJa0W2qaADwOzCC7oIz0fFOankF2cYqUXdm72vtorRE5t61ZFH2jmzOA6yVtRXYF4mfJitJ0SRPJ7kPQekn8rcDhZGOuvJqWNWtUzm3r8gotEBHxEPlDTr9tFNB0NsgXiozHrFqc29YMfCW1mZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy9Wj3gGYVWLU5Fsqev9zFx5RpUjMuh5vQZiZWS4XCDMzy+UCYWZmuVwgzMwsV6EFQtJzkh6V9JCkualtO0m3SZqfnrdN7ZJ0uaQFkh6RNK7I2Mwq4dy2ZlCLLYhD
ImKfiBif5icDsyJiDDArzQN8BBiTHpOAH9UgNrNKOLetS6vHLqajgevS9HXAx0rap0ZmDjBQ0tA6xGe2pZzb1qUUXSAC+G9J90ualNqGRMRSgPS8Q2ofBiwueW9LajNrRM5t6/KKvlDuwIhYImkH4DZJT3awrHLa4m0LZR/GSQAjR46sTpRmm8+5bV1eoVsQEbEkPS8DfgO8F3ihdfM6PS9Li7cAI0rePhxYktPnlIgYHxHjBw8eXGT4Zu1yblszKKxASOonaZvWaeDDwGPADODktNjJwE1pegZwUjrjY39gdevmulkjcW5bsyhyF9MQ4DeSWtfz84j4raT7gOmSJgKLgE+m5W8FDgcWAK8Cny0wNrNKOLetKRRWICJiIbB3TvuLwISc9gC+UFQ8ZtXi3LZm4SupzcwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpar8AIhqbukByXNTPM7S7pH0nxJ0yRtldp7pfkF6fVRRcdmtqWc19YMarEF8WXgiZL5fwMujYgxwEvAxNQ+EXgpInYFLk3LmTUq57V1eYUWCEnDgSOAq9O8gEOBG9Mi1wEfS9NHp3nS6xPS8mYNxXltzaLoLYgfAGcD69L89sCqiFib5luAYWl6GLAYIL2+Oi2/EUmTJM2VNHf58uVFxm7WnqrnNTi3rfEUViAkHQksi4j7S5tzFo0yXtvQEDElIsZHxPjBgwdXIVKz8hWV1+DctsbTo8C+DwSOknQ40BvoT/bLa6CkHunX1HBgSVq+BRgBtEjqAQwAVhYYn9mWcF5b0yhsCyIizomI4RExCjgeuD0iPg3cAXwiLXYycFOanpHmSa/fHhG5v7TM6sV5bc2kHtdBfA04U9ICsn2x16T2a4DtU/uZwOQ6xGa2pZzX1uUUuYtpvYi4E7gzTS8E3puzzBrgk7WIx6wanNfW1ZVdICT1L10+Irwf1bqEl19+mbVr166f32677eoYjVnj2GSBkPR54NvAa2w4+yKAXQqMy6xwf3novxgy5HP06dOH1ksTJLFw4cI6R2bWGMrZgvgK8K6IWFF0MGa19PK9v2bp0/MYNGhQvUMxa0jlHKR+Bni16EDMaq3HwKH07du33mGYNaxytiDOAf5X0j3A662NEfGlwqIyq4FtDz6ZAw44gP32249evXqtb7/88svrGJVZ4yinQFwF3A48yoahBcw6vRd/90OOP/5Ixo4dS7duHvnerK1yCsTaiDiz8EjMakzqziWXXFLvMMwaVjk/m+5Ig4gNlbRd66PwyMwK1munvZgyZQpLly5l5cqV6x9mlilnC+Lv0/M5JW0+zdU6vb8+/gcuuOBeLrjggvVtPs3VbINNFoiI2LkWgZjV2vDTr+HZC4+odxhmDaucC+VOymuPiKnVD8esdl55bBZTp774tvaTTspNebOmU84upveUTPcGJgAPAC4Q1qm9sXQ+9933JgBr1qxh1qxZjBs3zgXCLClnF9MZpfOSBgA/LSwisxrZ7m9P54qSXUyrV6/mxBNPrGNEZo1lS07+fhUYU+1AzOqtb9++zJ8/v95hmDWMco5B3MyGQfq6AXsC04sMyqwWlt34LY56/CoA1q1bx+OPP85xxx1X56jMGkc5xyAuLpleCzwfES0FxWNWM/3fewxnff59APTo0YOddtqJ4cOH1zkqs8ZRzjGIP9QiELNa6z1yLAcffHC9wzBrWO0WCEnPsmHXUlsREaOLCcmsWC1XTkRpepfpG4/mKolnnnmm9kGZNaCOtiDGt5nvBhxHdn+IBwuLyKxgQ0++dP30fed+mHXr1jF9+nQuvvhi9t133zpGZtZY2i0QEfEigKRu
wInAV4GHgCMi4vHahGdWfd379AcgYh0zZ87ke9/7Hvvssw+33HILe+65Z52jM2scHe1i6gl8DvhnYDZwdER429s6vXhrLa88ehsv33cTf/z433HTTTcxerT3mJq11dEupmfJzlr6AbAI2FvS3q0vRsSvC47NrBB/unIidOtO//FHc/jhh/Hwww/z8MMPr3/9mGOOqWN0Zo2jowLxe7KD1HunR6kAXCCsU+o9ah9AvLHsWW6++eaNXpPkAmGWdHQM4pQaxmFWM4OO+Of10z/xaK5m7fJ9Fs3MLJcLhJmZ5XKBMDOzXJssEJL6SvqmpP9M82MkHVnG+3pLulfSw5LmSfpWat9Z0j2S5kuaJmmr1N4rzS9Ir4+q7E8z69i6N9dw/vnnc9pppwEwf/58Zs6cucn3ObetWZSzBfET4HXgfWm+BfhOGe97HTg0IvYG9gEOk7Q/8G/ApRExBngJmJiWnwi8FBG7Apem5cwK8+Ktl9GrVy/uvvtuAIYPH843vvGNct7q3LamUE6BGB0RFwFvAkTEa7B+KJt2ReaVNNszPQI4FLgxtV8HfCxNH53mSa9PkLTJ9ZhtqbWrlnL22WfTs2dPAPr06UNEe8OPbeDctmZRToF4Q1If0sB9kkaT/YLaJEndJT0ELANuA54BVkXE2rRICzAsTQ8DFgOk11cD2+f0OUnSXElzly9fXk4YZrnUrQevvfYard/VzzzzDL169Srvvc5tawLlFIjzgN8CIyRdD8wCzi6n84h4KyL2AYYD7wX2yFssPef9onrbz7mImBIR4yNi/ODBg8sJwyzXgPd/msMOO4zFixfz6U9/mgkTJnDRRReV9V7ntjWDcu4HcZukB4D9yRL9yxGxYnNWEhGrJN2Z+hgoqUf6JTUcWJIWawFGAC2SegADgJWbsx6zzdFn53359X/8A3PmzCEiuOyyyxg0aNBm9eHctq6s3S0ISeNaH8BOwFKyhB+Z2jokabCkgWm6D/Ah4AngDuATabGTgZvS9Iw0T3r99ihnh7DZZnr9zwvWP55//nmGDh3KjjvuyKJFi3jggQc2+X7ntjWLjrYgvt/Ba60H5DoyFLhOUneyQjQ9ImZKehz4haTvkN1X4pq0/DXATyUtIPt1dXw5f4DZ5nrpjmvWT5/1/P/f6DVJ3H777ZvqwrltTaGjsZgOqaTjiHgEeNvdVyJiIdk+27bta4BPVrJOs3K844QL1k/fsQVjMTm3rVls8hiEpN7APwLvJ9ty+CNwZUp6s04r1r7BJZdcwuzZs5HEQQcdxOmnn07v3r3rHZpZQyjnLKapwLuAK4AfAnsCPy0yKLNaWDHzEubNm8cZZ5zBF7/4RR5//HFOPPHEeodl1jA2uQUB7JauGG11h6SH213arJN4c2UL11zzx/XzhxxyCHvv3fbWJ2bNq5wtiAfTMAIASNoP+J/iQjKrja2GjGbOnDnr5++55x4OPPDAOkZk1ljK2YLYDzhJ0qI0PxJ4QtKjZKMO7FVYdGYFen3JUxxwwAGMHDkSgEWLFrHHHnswduzY9VdXmzWzcgrEYYVHYVYHQ477Nv8zuf2ztUeNGlW7YMwaUDlXUj8vaVuyK0F7lLRv+ooiswbWY8AO9O/fn8WLF7N27dr17ePGbfI6ULOmUM5prucDp5ANRtZ69Wc5F8qZNbRVd/2Uva4/ndGjR6/fpVTmhXJmTaGcXUzHkQ35/UbRwZjV0l+fms2ypQvZaqut6h2KWUMq5yymx4CBRQdiVmtbDdqJVatW1TsMs4ZVzhbEBWSnuj5GyX0gIuKowqIyq4H++3+Sfffdl3e/+90b3QdixowZdYzKrHGUUyCuI7tF4qPAumLDMaudF2+5lIvP/Rpjx46lW7dyNqbNmks5BWJFRFxeeCRmNdatb3++9KUv1TsMs4ZVToG4X9IFZGPal+5i8mmu1qltNWQ055xzDkcdddRGu5h8mqtZppwC0Tqs8f4lbQ1/muuoybdU9P7ntmAYaOtc3li2kDlz
Vm803IZPczXboJwL5Sq6L4S1z0Wsvt5xwgVbdD8Is2ZRzhYEko4gG/J7/UD5EfHtooIyq5VbbrmFefPmsWbNhtubnHvuuXWMyKxxbPLUDUlXAp8CzgBEdmesnQqOy6xwL/7uh0ybNo0rrriCiOCXv/wlzz//fL3DMmsY5WxBHBARe0l6JCK+Jen7wK+LDsy6pkbarfb6n55k6oP/xV577cV5553HWWedxTHHHFO1/s06u3JO/n4tPb8qaUfgTWDn4kIyqw31yIbY6Nu3L0uWLKFnz548++yzdY7KrHGUswUxU9JA4HvAA2RnMP1noVGZ1UCf0e9h1apVfPWrX2XcuHFI4rTTTqt3WGYNo5yzmM5Pk7+SNBPoHRGriw3LrHgDDzyBgQMHcuyxx3LkkUeyZs0aBgwYUO+wzBpGu7uYJL1H0jtK5k8CpgPnS9quFsGZFeH1pU/z1isvrZ+fOnUqxx13HN/85jdZuXJlHSMzaywdHYO4CngDQNIHgAuBqcBqYErxoZkVY+Xv/h26ZxvPd911F5MnT+akk05iwIABTJo0qc7RmTWOjnYxdY+I1p9TnwKmRMSvyHY1PVR8aGbFiHVv0b3PNgBMmzaNSZMmceyxx3Lssceyzz771Dk6s8bR0RZEd0mtBWQCUDr+QFkX2Jk1pFhHrHsLgFmzZnHooRtGjSm99ahZs+voi/4G4A+SVpCd6vpHAEm7ku1mMuuU+u1xMC/8fDLd+vRnt637cNBBBwGwYMECH6Q2K9HuFkRE/CtwFnAt8P6IiJL3nLGpjiWNkHSHpCckzZP05dS+naTbJM1Pz9umdkm6XNICSY9I8pCaVogBB3yKbQ+ZyNZjJzB79uz196Net24dV1xxxSbf79y2ZtHhrqKImJPT9nSZfa8FzoqIByRtQzZs+G3AKcCsiLhQ0mRgMvA14CPAmPTYD/hRejarul7DdgegX79+69ve+c53lvt257Y1hcJuoxURS1vvGRERfwGeAIYBR5PdpY70/LE0fTQwNTJzgIGShhYVn9mWcm5bs6jJwWZJo8juK3EPMCQilkL2QZO0Q1psGLC45G0tqW1pm74mAZMARo4cWWjcnU0jjXPULJzb1pUVfiNeSVsDvwL+KSJe7mjRnLZ4W0PElIgYHxHjBw8eXK0wzTabc9u6ukILhKSeZB+g6yOidQTYF1o3r9PzstTeAowoeftwYEmR8ZltKee2NYPCCoSyU0OuAZ6IiEtKXpoBnJymTwZuKmk/KZ3xsT+wunVz3ayROLetWRR5DOJA4ETg0ZIrr79ONmTHdEkTgUVkNyACuBU4HFgAvAp8tsDYzCrh3LamUFiBiIjZ5O97hezK7LbLB/CFouIxqxbntjWLwg9Sm5lZ5+QCYWZmuVwgzMwslwuEmZnlcoEwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuWpyP4iuoNJ7LYDvt2BmnYsLhJlZA2qEG4B5F5OZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLhcIMzPL5QJhZma5XCDMzCyXC4SZmeVygTAzs1wuEGZmlssFwszMcrlAmJlZLhcIMzPL5QJhZma5CisQkn4saZmkx0ratpN0m6T56Xnb1C5Jl0taIOkRSeOKisusUs5taxZF3jDoWuCHwNSStsnArIi4UNLkNP814CPAmPTYD/hRerY68l302nUtzm1rAoVtQUTEXcDKNs1HA9el6euAj5W0T43MHGCgpKFFxWZWCee2NYtaH4MYEhFLAdLzDql9GLC4ZLmW1GbWWTi3rctplIPUymmL3AWlSZLmSpq7fPnygsMyq5hz2zqtWheIF1o3r9PzstTeAowoWW44sCSvg4iYEhHjI2L84MGDCw3WbDM4t63LqXWBmAGcnKZPBm4qaT8pnfGxP7C6dXPdrJNwbluXU9hZTJJuAD4IDJLUApwHXAhMlzQRWAR8Mi1+K3A4sAB4FfhsUXGZVcq5bc2isAIRESe089KEnGUD
+EJRsZhVk3PbmkWR10GYWQEqvT6li16bYgVwgTCzqvIFll1Ho5zmamZmDcYFwszMcrlAmJlZLhcIMzPL5QJhZma5XCDMzCyXC4SZmeVygTAzs1wuEGZmlstXUptZw2v04UW66tXj3oIwM7NcLhBmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuXyaq5k1nUY/bbZReAvCzMxyeQvCrMn517S1x1sQZmaWywXCzMxyuUCYmVkuFwgzM8vlAmFmZrlcIMzMLJcLhJmZ5WqoAiHpMElPSVogaXK94zGrFue2dUYNUyAkdQf+HfgIsCdwgqQ96xuVWeWc29ZZNUyBAN4LLIiIhRHxBvAL4Og6x2RWDc5t65QUEfWOAQBJnwAOi4hT0/yJwH4R8cU2y00CJqXZ3YCntnCVg4AVW/jeWvRXRJ/NGGMl/e0UEYMrDcC5XXh/RfTZ6P1V2mdZud1IYzEpp+1t1SsipgBTKl6ZNDcixlfaT1H9FdFnM8ZYxN+8JWHktDVtbjsPG7fPthppF1MLMKJkfjiwpE6xmFWTc9s6pUYqEPcBYyTtLGkr4HhgRp1jMqsG57Z1Sg2ziyki1kr6IvA7oDvw44iYV+AqK96UL7i/IvpsxhiL+Js3i3O78P6K6LPR+yuqz400zEFqMzNrLI20i8nMzBqIC4SZmeVqygJRzWEPJP1Y0jJJj1UpthGS7pD0hKR5kr5chT57S7pX0sOpz29VKdbukh6UNLMKfT0n6VFJD0maW6X4Bkq6UdKT6d/zfdXot1FVeziPRs/tzpDXqb+q5nZN8zoimupBdpDwGWAXYCvgYWDPCvr7ADAOeKxK8Q0FxqXpbYCnK4kv9SNg6zTdE7gH2L8KsZ4J/ByYWYW+ngMGVfn/+jrg1DS9FTCw1vlWq0e18zr12dC53RnyOvVX1dyuZV434xZEVYc9iIi7gJXVCi4ilkbEA2n6L8ATwLAK+4yIeCXN9kyPis5OkDQcOAK4upJ+iiKpP9kX3DUAEfFGRKyqb1SFqvpwHo2e287r4vO6GQvEMGBxyXwLFX4BF0XSKGBfsl9GlfbVXdJDwDLgtoiotM8fAGcD6yqNLQngvyXdn4acqNQuwHLgJ2l3wdWS+lWh30bVafIaqpfbnSCwRb6GAAAEDElEQVSvobq5XdO8bsYCUdawB/UmaWvgV8A/RcTLlfYXEW9FxD5kV/G+V9K7K4jtSGBZRNxfaVwlDoyIcWQjnn5B0gcq7K8H2e6RH0XEvsBfga48zHanyGuobm53gryG6uZ2TfO6GQtEww97IKkn2Qfo+oj4dTX7TpujdwKHVdDNgcBRkp4j25VxqKSfVRjXkvS8DPgN2S6TSrQALSW/KG8k+2B1VQ2f11BcbjdqXqfYqpnbNc3rZiwQDT3sgSSR7V98IiIuqVKfgyUNTNN9gA8BT25pfxFxTkQMj4hRZP9+t0fEZyqIr5+kbVqngQ8DFZ05ExF/BhZL2i01TQAer6TPBtfQeQ3Vz+1Gz+sUV1Vzu9Z53TBDbdRKVHnYA0k3AB8EBklqAc6LiGsqCPFA4ETg0bRvFeDrEXFrBX0OBa5TduOabsD0iKjKKXxVMgT4Tfb9QQ/g5xHx2yr0ewZwffrCXAh8tgp9NqRq5zV0itxu9LyGYnK7ZnntoTbMzCxXM+5iMjOzMrhAmJlZLhcIMzPL5QJhZma5XCDMzCyXC0QdSXpl00utX/ZfJH2lGv1vznrT8h+s1siW1hyc212DC4SZmeVygWgwkj4q6Z40ENfvJQ0peXlvSbdLmi/ptJL3fFXSfZIe2Zwx8dOvpztLxpa/Pl3t2npvgSclzQaOKXlPP2X3CbgvxXh0aj9T0o/T9FhJj0nqW+m/h3Udzu1OqKhxxP0oa1z3V3LatmXDBYynAt9P0/9CNsZ/H2AQ2cidO5Jduj+FbLC2bsBM4APt9V/aTnaV7GqycXu6AXcD7wd6p/7HpH6nk8bGB74LfCZNDyQb079fev9dwMeBuWQD
lNX939iP+jyc213j0XRDbXQCw4FpkoaS3Qzk2ZLXboqI14DXJN1BNujX+8k+SA+mZbYmS/67ylzfvRHRApCGPxgFvAI8GxHzU/vPgNZhij9MNqBZ6z7j3sDIiHhC0inAI8BVEfE/m/VXWzNwbncyLhCN5wrgkoiYIemDZL+uWrUdFyXIfgVdEBFXbeH6Xi+ZfosNOdHeGCwCjo2Ip3JeG0P2AdxxC2Oxrs253cn4GETjGQD8KU2f3Oa1o5Xdh3d7sk3o+8gGZ/ucsjH2kTRM0g4VxvAksLOk0Wn+hJLXfgecUbI/d9/0PAC4jOxuV9tL+kSFMVjX49zuZLwFUV990yiZrS4h+1X1S0l/AuYAO5e8fi9wCzASOD+yceaXSNoDuDvl9SvAZ8jusLVFImKNsjtf3SJpBTAbaL0Ry/lkd916JH2QngOOBC4F/iMinpY0EbhD0l2RjYFvzce53QV4NFczM8vlXUxmZpbLBcLMzHK5QJiZWS4XCDMzy+UCYWZmuVwgzMwslwuEmZnl+j8FGGFr9zjK1QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plot the per-class sample counts of the train and test splits side by side.\n",
    "num_classes = len(label_dict)  # derive from the data instead of hard-coding 7\n",
    "print(\"{}Train Data and Test Data Info{}\".format(\"*\"*20, \"*\"*20))\n",
    "print(\"==> Label list : \" + (\"\\n    {}\" * num_classes).format(*[(i, j) for j, i in label_dict.items()]))\n",
    "inds, nums = np.unique(y[data.train_mask].numpy(), return_counts=True)\n",
    "plt.figure(1)\n",
    "plt.subplot(121)\n",
    "plt.bar(x=inds, height=nums, width=0.8, bottom=0, align='center')\n",
    "plt.xticks(range(num_classes))\n",
    "plt.xlabel(\"Label Index\")\n",
    "plt.ylabel(\"Sample Num\")\n",
    "plt.ylim((0, 600))\n",
    "plt.title(\"Train Data Statistics\")\n",
    "inds, nums = np.unique(y[data.test_mask].numpy(), return_counts=True)\n",
    "plt.subplot(122)\n",
    "plt.bar(x=inds, height=nums, width=0.8, bottom=0, align='center')\n",
    "plt.xticks(range(num_classes))\n",
    "plt.xlabel(\"Label Index\")\n",
    "plt.ylabel(\"Sample Num\")\n",
    "plt.ylim((0, 600))\n",
    "plt.title(\"Test Data Statistics\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### PyG构建GCN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "GCN layer built on the MessagePassing base class, following the official\n",
    "PyG tutorial example.\n",
    "\n",
    "NOTE(review): this class shadows the GCNConv imported from torch_geometric.nn\n",
    "in the imports cell; the redefinition here is the one used by Net below.\n",
    "\n",
    "__init__:\n",
    "    in_channels : (int) dimensionality of the input node features\n",
    "    out_channels : (int) dimensionality of the output node features\n",
    "\n",
    "forward:\n",
    "    x : (Tensor) node feature matrix, shape (N, in_channels)\n",
    "    edge_index : (LongTensor) edge matrix, shape (2, E)\n",
    "    returns : (Tensor) per-node outputs, shape (N, out_channels)\n",
    "'''\n",
    "class GCNConv(MessagePassing):\n",
    "    def __init__(self, in_channels, out_channels):\n",
    "        super(GCNConv, self).__init__(aggr='add')  # \"Add\" aggregation.\n",
    "        self.lin = torch.nn.Linear(in_channels, out_channels)\n",
    "\n",
    "    def forward(self, x, edge_index):\n",
    "        # x has shape [N, in_channels]\n",
    "        # edge_index has shape [2, E]\n",
    "\n",
    "        # Step 1: Add self-loops to the adjacency matrix.\n",
    "        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n",
    "\n",
    "        # Step 2: Linearly transform node feature matrix.\n",
    "        x = self.lin(x)  # (N, in_channels) -> (N, out_channels)\n",
    "        # Step 3-5: Start propagating messages.\n",
    "        \n",
    "        # Step 3: Compute per-edge symmetric normalization coefficients\n",
    "        # deg(i)^-0.5 * deg(j)^-0.5 from the self-loop-augmented edge_index.\n",
    "        row, col = edge_index  # [E,], [E,]\n",
    "        deg = degree(row, x.size(0), dtype=x.dtype)  # [N, ]\n",
    "        deg_inv_sqrt = deg.pow(-0.5)   # [N, ]\n",
    "        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]  # [E, ]\n",
    "        \n",
    "        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, norm=norm)\n",
    "\n",
    "    def message(self, x_j, norm):\n",
    "        # x_j has shape [E, out_channels]: the transformed feature of the\n",
    "        # source node of each edge, gathered by propagate().\n",
    "        # norm was precomputed in forward() and is routed here by propagate().\n",
    "\n",
    "        # Step 4: Weight each incoming message by its edge coefficient.\n",
    "        return norm.view(-1, 1) * x_j\n",
    "\n",
    "    def update(self, aggr_out):\n",
    "        # aggr_out has shape [N, out_channels]\n",
    "\n",
    "        # Step 5: Return new node embeddings.\n",
    "        return aggr_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Two-layer GCN model. The first layer maps the node matrix\n",
    "        (N, feat_dim) -> (N, hidden_dim)\n",
    "and the second maps\n",
    "        (N, hidden_dim) -> (N, num_class).\n",
    "ReLU and dropout sit between the layers; the output is the\n",
    "log-softmax-normalized per-node class scores.\n",
    "'''\n",
    "class Net(torch.nn.Module):\n",
    "    def __init__(self, feat_dim, num_class, hidden_dim=16):\n",
    "        # hidden_dim defaults to 16, matching the previously hard-coded width,\n",
    "        # so existing Net(feat_dim, num_class) callers are unaffected.\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = GCNConv(feat_dim, hidden_dim)\n",
    "        self.conv2 = GCNConv(hidden_dim, num_class)\n",
    "\n",
    "    def forward(self, data):\n",
    "        # data must provide .x (node features) and .edge_index (2, E).\n",
    "        x, edge_index = data.x, data.edge_index\n",
    "\n",
    "        x = self.conv1(x, edge_index)\n",
    "        x = F.relu(x)\n",
    "        x = F.dropout(x, training=self.training)  # active only in train mode\n",
    "        x = self.conv2(x, edge_index)\n",
    "\n",
    "        return F.log_softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0/200] Loss 1.9322, train acc 0.2717\n",
      "[Epoch 1/200] Loss 1.8872, train acc 0.3091\n",
      "[Epoch 2/200] Loss 1.8363, train acc 0.3443\n",
      "[Epoch 3/200] Loss 1.7857, train acc 0.3858\n",
      "[Epoch 4/200] Loss 1.7171, train acc 0.4491\n",
      "[Epoch 5/200] Loss 1.6643, train acc 0.4725\n",
      "[Epoch 6/200] Loss 1.5873, train acc 0.5135\n",
      "[Epoch 7/200] Loss 1.5310, train acc 0.5386\n",
      "[Epoch 8/200] Loss 1.4638, train acc 0.5650\n",
      "[Epoch 9/200] Loss 1.3718, train acc 0.6300\n",
      "Accuracy: 0.5740\n",
      "[Epoch 10/200] Loss 1.2912, train acc 0.6797\n",
      "[Epoch 11/200] Loss 1.2239, train acc 0.7131\n",
      "[Epoch 12/200] Loss 1.1491, train acc 0.7559\n",
      "[Epoch 13/200] Loss 1.0590, train acc 0.7875\n",
      "[Epoch 14/200] Loss 1.0004, train acc 0.8138\n",
      "[Epoch 15/200] Loss 0.9266, train acc 0.8226\n",
      "[Epoch 16/200] Loss 0.8660, train acc 0.8296\n",
      "[Epoch 17/200] Loss 0.8002, train acc 0.8355\n",
      "[Epoch 18/200] Loss 0.7828, train acc 0.8331\n",
      "[Epoch 19/200] Loss 0.7468, train acc 0.8267\n",
      "Accuracy: 0.8240\n",
      "[Epoch 20/200] Loss 0.6775, train acc 0.8484\n",
      "[Epoch 21/200] Loss 0.6471, train acc 0.8431\n",
      "[Epoch 22/200] Loss 0.5792, train acc 0.8671\n",
      "[Epoch 23/200] Loss 0.5613, train acc 0.8671\n",
      "[Epoch 24/200] Loss 0.5230, train acc 0.8677\n",
      "[Epoch 25/200] Loss 0.5038, train acc 0.8770\n",
      "[Epoch 26/200] Loss 0.4938, train acc 0.8741\n",
      "[Epoch 27/200] Loss 0.4529, train acc 0.8823\n",
      "[Epoch 28/200] Loss 0.4469, train acc 0.8835\n",
      "[Epoch 29/200] Loss 0.4164, train acc 0.8893\n",
      "Accuracy: 0.8440\n",
      "[Epoch 30/200] Loss 0.3965, train acc 0.8970\n",
      "[Epoch 31/200] Loss 0.3995, train acc 0.8940\n",
      "[Epoch 32/200] Loss 0.3938, train acc 0.8946\n",
      "[Epoch 33/200] Loss 0.3910, train acc 0.8888\n",
      "[Epoch 34/200] Loss 0.3617, train acc 0.8940\n",
      "[Epoch 35/200] Loss 0.3730, train acc 0.8958\n",
      "[Epoch 36/200] Loss 0.3599, train acc 0.8929\n",
      "[Epoch 37/200] Loss 0.3287, train acc 0.9139\n",
      "[Epoch 38/200] Loss 0.3284, train acc 0.9098\n",
      "[Epoch 39/200] Loss 0.3431, train acc 0.9087\n",
      "Accuracy: 0.8580\n",
      "[Epoch 40/200] Loss 0.3086, train acc 0.9128\n",
      "[Epoch 41/200] Loss 0.3030, train acc 0.9133\n",
      "[Epoch 42/200] Loss 0.2969, train acc 0.9233\n",
      "[Epoch 43/200] Loss 0.3038, train acc 0.9075\n",
      "[Epoch 44/200] Loss 0.3018, train acc 0.9087\n",
      "[Epoch 45/200] Loss 0.2936, train acc 0.9151\n",
      "[Epoch 46/200] Loss 0.3024, train acc 0.9028\n",
      "[Epoch 47/200] Loss 0.2971, train acc 0.9145\n",
      "[Epoch 48/200] Loss 0.2743, train acc 0.9157\n",
      "[Epoch 49/200] Loss 0.2815, train acc 0.9151\n",
      "Accuracy: 0.8460\n",
      "[Epoch 50/200] Loss 0.2864, train acc 0.9151\n",
      "[Epoch 51/200] Loss 0.2846, train acc 0.9192\n",
      "[Epoch 52/200] Loss 0.2673, train acc 0.9233\n",
      "[Epoch 53/200] Loss 0.2743, train acc 0.9227\n",
      "[Epoch 54/200] Loss 0.2804, train acc 0.9116\n",
      "[Epoch 55/200] Loss 0.2727, train acc 0.9210\n",
      "[Epoch 56/200] Loss 0.2649, train acc 0.9210\n",
      "[Epoch 57/200] Loss 0.2797, train acc 0.9151\n",
      "[Epoch 58/200] Loss 0.2619, train acc 0.9268\n",
      "[Epoch 59/200] Loss 0.2557, train acc 0.9256\n",
      "Accuracy: 0.8480\n",
      "[Epoch 60/200] Loss 0.2606, train acc 0.9245\n",
      "[Epoch 61/200] Loss 0.2646, train acc 0.9198\n",
      "[Epoch 62/200] Loss 0.2621, train acc 0.9233\n",
      "[Epoch 63/200] Loss 0.2573, train acc 0.9251\n",
      "[Epoch 64/200] Loss 0.2417, train acc 0.9245\n",
      "[Epoch 65/200] Loss 0.2526, train acc 0.9297\n",
      "[Epoch 66/200] Loss 0.2518, train acc 0.9321\n",
      "[Epoch 67/200] Loss 0.2509, train acc 0.9286\n",
      "[Epoch 68/200] Loss 0.2432, train acc 0.9303\n",
      "[Epoch 69/200] Loss 0.2456, train acc 0.9286\n",
      "Accuracy: 0.8440\n",
      "[Epoch 70/200] Loss 0.2361, train acc 0.9268\n",
      "[Epoch 71/200] Loss 0.2518, train acc 0.9292\n",
      "[Epoch 72/200] Loss 0.2464, train acc 0.9245\n",
      "[Epoch 73/200] Loss 0.2495, train acc 0.9215\n",
      "[Epoch 74/200] Loss 0.2407, train acc 0.9262\n",
      "[Epoch 75/200] Loss 0.2364, train acc 0.9368\n",
      "[Epoch 76/200] Loss 0.2354, train acc 0.9303\n",
      "[Epoch 77/200] Loss 0.2380, train acc 0.9309\n",
      "[Epoch 78/200] Loss 0.2262, train acc 0.9338\n",
      "[Epoch 79/200] Loss 0.2221, train acc 0.9420\n",
      "Accuracy: 0.8380\n",
      "[Epoch 80/200] Loss 0.2389, train acc 0.9350\n",
      "[Epoch 81/200] Loss 0.2286, train acc 0.9362\n",
      "[Epoch 82/200] Loss 0.2223, train acc 0.9362\n",
      "[Epoch 83/200] Loss 0.2351, train acc 0.9374\n",
      "[Epoch 84/200] Loss 0.2247, train acc 0.9315\n",
      "[Epoch 85/200] Loss 0.2409, train acc 0.9315\n",
      "[Epoch 86/200] Loss 0.2198, train acc 0.9379\n",
      "[Epoch 87/200] Loss 0.2172, train acc 0.9385\n",
      "[Epoch 88/200] Loss 0.2196, train acc 0.9379\n",
      "[Epoch 89/200] Loss 0.2232, train acc 0.9327\n",
      "Accuracy: 0.8360\n",
      "[Epoch 90/200] Loss 0.2159, train acc 0.9356\n",
      "[Epoch 91/200] Loss 0.2290, train acc 0.9297\n",
      "[Epoch 92/200] Loss 0.2233, train acc 0.9385\n",
      "[Epoch 93/200] Loss 0.2180, train acc 0.9379\n",
      "[Epoch 94/200] Loss 0.2248, train acc 0.9338\n",
      "[Epoch 95/200] Loss 0.2292, train acc 0.9309\n",
      "[Epoch 96/200] Loss 0.2189, train acc 0.9368\n",
      "[Epoch 97/200] Loss 0.2125, train acc 0.9362\n",
      "[Epoch 98/200] Loss 0.2210, train acc 0.9362\n",
      "[Epoch 99/200] Loss 0.2301, train acc 0.9327\n",
      "Accuracy: 0.8340\n",
      "[Epoch 100/200] Loss 0.2169, train acc 0.9391\n",
      "[Epoch 101/200] Loss 0.2151, train acc 0.9374\n",
      "[Epoch 102/200] Loss 0.1990, train acc 0.9444\n",
      "[Epoch 103/200] Loss 0.2059, train acc 0.9356\n",
      "[Epoch 104/200] Loss 0.2099, train acc 0.9297\n",
      "[Epoch 105/200] Loss 0.2201, train acc 0.9321\n",
      "[Epoch 106/200] Loss 0.1994, train acc 0.9391\n",
      "[Epoch 107/200] Loss 0.2202, train acc 0.9379\n",
      "[Epoch 108/200] Loss 0.2065, train acc 0.9438\n",
      "[Epoch 109/200] Loss 0.2010, train acc 0.9379\n",
      "Accuracy: 0.8300\n",
      "[Epoch 110/200] Loss 0.2072, train acc 0.9420\n",
      "[Epoch 111/200] Loss 0.1912, train acc 0.9368\n",
      "[Epoch 112/200] Loss 0.2134, train acc 0.9327\n",
      "[Epoch 113/200] Loss 0.1926, train acc 0.9415\n",
      "[Epoch 114/200] Loss 0.2038, train acc 0.9420\n",
      "[Epoch 115/200] Loss 0.1940, train acc 0.9415\n",
      "[Epoch 116/200] Loss 0.2025, train acc 0.9356\n",
      "[Epoch 117/200] Loss 0.1924, train acc 0.9432\n",
      "[Epoch 118/200] Loss 0.1973, train acc 0.9374\n",
      "[Epoch 119/200] Loss 0.1891, train acc 0.9456\n",
      "Accuracy: 0.8260\n",
      "[Epoch 120/200] Loss 0.1884, train acc 0.9432\n",
      "[Epoch 121/200] Loss 0.1977, train acc 0.9379\n",
      "[Epoch 122/200] Loss 0.1914, train acc 0.9403\n",
      "[Epoch 123/200] Loss 0.1812, train acc 0.9508\n",
      "[Epoch 124/200] Loss 0.2056, train acc 0.9397\n",
      "[Epoch 125/200] Loss 0.1940, train acc 0.9391\n",
      "[Epoch 126/200] Loss 0.1797, train acc 0.9473\n",
      "[Epoch 127/200] Loss 0.1789, train acc 0.9502\n",
      "[Epoch 128/200] Loss 0.1965, train acc 0.9409\n",
      "[Epoch 129/200] Loss 0.2014, train acc 0.9450\n",
      "Accuracy: 0.8300\n",
      "[Epoch 130/200] Loss 0.2039, train acc 0.9438\n",
      "[Epoch 131/200] Loss 0.1953, train acc 0.9409\n",
      "[Epoch 132/200] Loss 0.1905, train acc 0.9415\n",
      "[Epoch 133/200] Loss 0.1856, train acc 0.9467\n",
      "[Epoch 134/200] Loss 0.1807, train acc 0.9461\n",
      "[Epoch 135/200] Loss 0.1907, train acc 0.9438\n",
      "[Epoch 136/200] Loss 0.1807, train acc 0.9420\n",
      "[Epoch 137/200] Loss 0.1926, train acc 0.9391\n",
      "[Epoch 138/200] Loss 0.1937, train acc 0.9385\n",
      "[Epoch 139/200] Loss 0.1885, train acc 0.9438\n",
      "Accuracy: 0.8260\n",
      "[Epoch 140/200] Loss 0.1940, train acc 0.9456\n",
      "[Epoch 141/200] Loss 0.1812, train acc 0.9450\n",
      "[Epoch 142/200] Loss 0.1844, train acc 0.9473\n",
      "[Epoch 143/200] Loss 0.1870, train acc 0.9403\n",
      "[Epoch 144/200] Loss 0.1941, train acc 0.9385\n",
      "[Epoch 145/200] Loss 0.1969, train acc 0.9403\n",
      "[Epoch 146/200] Loss 0.1903, train acc 0.9479\n",
      "[Epoch 147/200] Loss 0.1852, train acc 0.9467\n",
      "[Epoch 148/200] Loss 0.1906, train acc 0.9350\n",
      "[Epoch 149/200] Loss 0.1770, train acc 0.9397\n",
      "Accuracy: 0.8280\n",
      "[Epoch 150/200] Loss 0.1751, train acc 0.9444\n",
      "[Epoch 151/200] Loss 0.1827, train acc 0.9473\n",
      "[Epoch 152/200] Loss 0.1814, train acc 0.9420\n",
      "[Epoch 153/200] Loss 0.1829, train acc 0.9438\n",
      "[Epoch 154/200] Loss 0.1822, train acc 0.9426\n",
      "[Epoch 155/200] Loss 0.1774, train acc 0.9485\n",
      "[Epoch 156/200] Loss 0.1789, train acc 0.9432\n",
      "[Epoch 157/200] Loss 0.1737, train acc 0.9508\n",
      "[Epoch 158/200] Loss 0.1872, train acc 0.9403\n",
      "[Epoch 159/200] Loss 0.1682, train acc 0.9485\n",
      "Accuracy: 0.8200\n",
      "[Epoch 160/200] Loss 0.1782, train acc 0.9537\n",
      "[Epoch 161/200] Loss 0.1800, train acc 0.9444\n",
      "[Epoch 162/200] Loss 0.1784, train acc 0.9415\n",
      "[Epoch 163/200] Loss 0.1902, train acc 0.9415\n",
      "[Epoch 164/200] Loss 0.1768, train acc 0.9537\n",
      "[Epoch 165/200] Loss 0.1725, train acc 0.9444\n",
      "[Epoch 166/200] Loss 0.1751, train acc 0.9467\n",
      "[Epoch 167/200] Loss 0.1705, train acc 0.9491\n",
      "[Epoch 168/200] Loss 0.1854, train acc 0.9467\n",
      "[Epoch 169/200] Loss 0.1759, train acc 0.9461\n",
      "Accuracy: 0.8180\n",
      "[Epoch 170/200] Loss 0.1691, train acc 0.9491\n",
      "[Epoch 171/200] Loss 0.1688, train acc 0.9479\n",
      "[Epoch 172/200] Loss 0.1808, train acc 0.9415\n",
      "[Epoch 173/200] Loss 0.1762, train acc 0.9467\n",
      "[Epoch 174/200] Loss 0.1745, train acc 0.9432\n",
      "[Epoch 175/200] Loss 0.1759, train acc 0.9479\n",
      "[Epoch 176/200] Loss 0.1716, train acc 0.9514\n",
      "[Epoch 177/200] Loss 0.1729, train acc 0.9467\n",
      "[Epoch 178/200] Loss 0.1793, train acc 0.9432\n",
      "[Epoch 179/200] Loss 0.1729, train acc 0.9485\n",
      "Accuracy: 0.8160\n",
      "[Epoch 180/200] Loss 0.1730, train acc 0.9537\n",
      "[Epoch 181/200] Loss 0.1729, train acc 0.9549\n",
      "[Epoch 182/200] Loss 0.1604, train acc 0.9561\n",
      "[Epoch 183/200] Loss 0.1771, train acc 0.9456\n",
      "[Epoch 184/200] Loss 0.1591, train acc 0.9549\n",
      "[Epoch 185/200] Loss 0.1716, train acc 0.9456\n",
      "[Epoch 186/200] Loss 0.1666, train acc 0.9456\n",
      "[Epoch 187/200] Loss 0.1699, train acc 0.9526\n",
      "[Epoch 188/200] Loss 0.1717, train acc 0.9432\n",
      "[Epoch 189/200] Loss 0.1684, train acc 0.9485\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.8180\n",
      "[Epoch 190/200] Loss 0.1786, train acc 0.9450\n",
      "[Epoch 191/200] Loss 0.1624, train acc 0.9508\n",
      "[Epoch 192/200] Loss 0.1647, train acc 0.9532\n",
      "[Epoch 193/200] Loss 0.1694, train acc 0.9479\n",
      "[Epoch 194/200] Loss 0.1665, train acc 0.9479\n",
      "[Epoch 195/200] Loss 0.1660, train acc 0.9467\n",
      "[Epoch 196/200] Loss 0.1641, train acc 0.9473\n",
      "[Epoch 197/200] Loss 0.1598, train acc 0.9549\n",
      "[Epoch 198/200] Loss 0.1628, train acc 0.9491\n",
      "[Epoch 199/200] Loss 0.1727, train acc 0.9461\n",
      "Accuracy: 0.8180\n"
     ]
    }
   ],
   "source": [
    "'''\n",
    "Train the model: 200 epochs of full-batch gradient descent on the Cora\n",
    "graph, printing train accuracy each epoch and test accuracy every 10 epochs.\n",
    "'''\n",
    "# Use the GPU when available, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = Net(feat_dim=data.num_node_features, num_class=7).to(device)         # Initialize model\n",
    "data = data.to(device)                                                       \n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) # Initialize optimizer and training params\n",
    "\n",
    "for epoch in range(200):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    # Get output: one forward pass over the whole graph (full-batch training).\n",
    "    out = model(data)\n",
    "    \n",
    "    # Get loss: negative log-likelihood restricted to training nodes\n",
    "    # (F.nll_loss expects log-probabilities, so Net presumably ends in log_softmax — confirm).\n",
    "    loss = F.nll_loss(out[data.train_mask.bool()], data.y[data.train_mask.bool()].long())\n",
    "    _, pred = out.max(dim=1)\n",
    "    \n",
    "    # Get predictions and calculate training accuracy\n",
    "    # (epoch is printed 0-based, so the log runs from 0/200 through 199/200).\n",
    "    correct = float(torch.masked_select(pred, data.train_mask.bool()).eq(torch.masked_select(data.y, data.train_mask.bool())).sum().item())\n",
    "    acc = correct / data.train_mask.sum().item()\n",
    "    # NOTE(review): loss.cpu().detach().data.item() is equivalent to loss.item().\n",
    "    print('[Epoch {}/200] Loss {:.4f}, train acc {:.4f}'.format(epoch, loss.cpu().detach().data.item(), acc))\n",
    "    \n",
    "    # Backward\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    \n",
    "    # Evaluation on test data every 10 epochs\n",
    "    # NOTE(review): this forward pass still builds an autograd graph; wrapping\n",
    "    # it in torch.no_grad() would save memory without changing the results.\n",
    "    if (epoch+1) % 10 == 0:\n",
    "        model.eval()\n",
    "        _, pred = model(data).max(dim=1)\n",
    "        correct = float(pred[data.test_mask.bool()].eq(data.y[data.test_mask.bool()]).sum().item())\n",
    "        acc = correct / data.test_mask.sum().item()\n",
    "        # NOTE(review): this is test-set accuracy printed mid-training — it appears\n",
    "        # to double as a validation proxy; a separate validation mask would be cleaner.\n",
    "        print('Accuracy: {:.4f}'.format(acc))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
